from dataclasses import dataclass, field
from typing import Literal, Callable, Any, Optional, Type, Tuple
from torch.nn import Sequential
import os

from .mppi.core import (
    CostBaseCfg, NNCostCfg, MPPICfg
)
from .mppi.core.cost.utils import fcnn_factory

from .cost import TDCost, GAIfOCost, CostToGo

@dataclass
class GAIfOCostCfg(NNCostCfg):
    '''Settings for the neural-network cost implemented by ``GAIfOCost``.'''

    state_dim: int = 12
    '''Defaults to 12, the same value as the default ``model_kwargs["input_dim"]``'''

    model_kwargs: dict = field(
        default_factory=lambda: dict(
            input_dim=12,
            hidden_dims=[32, 32],
            activation="lrelu",
            use_spectral_norm=True,
        )
    )
    '''Model keyword arguments, rebuilt for every instance.
    Presumably forwarded to ``model_factory`` -- confirm against the cost class.'''

    # feat_dim: int = None # Backwards compatibility

    class_type: Type[GAIfOCost] = GAIfOCost
    '''Cost class associated with this config'''

    model_factory: Callable[[Any], Sequential] = fcnn_factory
    '''Callable annotated as returning a ``Sequential``; defaults to ``fcnn_factory``'''

    clip_costs: Optional[Tuple[float, float]] = None
    '''Optional pair of floats, ``None`` by default.
    Presumably (low, high) clipping bounds with ``None`` disabling clipping -- confirm.'''

@dataclass
class CostToGoCfg(NNCostCfg):
    '''Settings for the neural-network cost implemented by ``CostToGo``.'''

    class_type: Type[CostToGo] = CostToGo
    '''Cost class associated with this config'''

    model_factory: Callable[[Any], Sequential] = fcnn_factory
    '''Callable annotated as returning a ``Sequential``; defaults to ``fcnn_factory``'''

    clip_costs: Optional[Tuple[float, float]] = None
    '''Optional pair of floats, ``None`` by default.
    Presumably (low, high) clipping bounds with ``None`` disabling clipping -- confirm.'''

@dataclass
class TDCostCfg(CostBaseCfg):
    '''Settings for ``TDCost``: a single step cost and a terminal state cost,
    each with its own coefficient. All four settings default to ``None``.'''

    single_step_class_cfg: CostBaseCfg = None
    '''Cost config used for the single step cost function'''

    terminal_state_class_cfg: CostBaseCfg = None
    '''Cost config used for the terminal state cost function'''

    ss_coeff: float = None
    '''Coefficient applied to the single step cost'''

    ts_coeff: float = None
    '''Coefficient applied to the terminal state cost (cost value function)'''

    class_type: Type[TDCost] = TDCost
    '''Cost class associated with this config'''

@dataclass
class MPAILTDCostCfg(TDCostCfg):
    '''TD cost config that pairs a ``GAIfOCostCfg`` single step cost with a
    ``CostToGoCfg`` terminal state cost.

    The nested configs are built through ``default_factory`` so that every
    instance owns its own copies. Using a config instance directly as the
    default would share one mutable object between all instances, and
    dataclasses reject such unhashable defaults with ``ValueError`` on
    Python 3.11+.
    '''

    single_step_class_cfg: GAIfOCostCfg = field(default_factory=GAIfOCostCfg)
    '''Config for the single step cost function'''

    terminal_state_class_cfg: CostToGoCfg = field(default_factory=CostToGoCfg)
    '''Config for the terminal state cost function'''

    ss_coeff: float = 1.0
    '''Single step cost coefficient'''

    ts_coeff: float = 1.0
    '''Terminal state cost coefficient for cost value function'''

    risk: Optional[float] = None
    '''Discount factor for cost computation.
    NOTE(review): the field is named ``risk`` but documented as a discount factor -- confirm.'''

    class_type: Type[TDCost] = TDCost
    '''Type of cost function'''

    feature_inds: Optional[list] = None
    '''Feature indices to apply costing to. If None, applies to all features'''

@dataclass
class MPAILPolicyCfg(MPPICfg):
    '''MPPI config extended with action distribution, temperature and cost
    settings for the MPAIL policy. Every field added here defaults to ``None``.'''

    action_dist: Literal["normal","categorical"] = None
    '''Action selector distribution: ``"normal"`` or ``"categorical"``'''

    action_dist_params: dict = None
    '''Parameters of the action distribution'''

    temp_lr: float = None
    '''Temperature learning rate'''

    min_temp: float = None
    '''Lower bound on the action distribution temperature'''

    cost_cfg: MPAILTDCostCfg = None
    '''Configuration of the cost (annotated as ``MPAILTDCostCfg``)'''

@dataclass
class ValueLearnerCfg:
    '''Settings for learning the value function. Every field defaults to ``None``.'''

    opt: str = None
    '''Type of optimizer'''

    opt_params: dict = None
    '''Parameters for the optimizer'''

    use_clipped_value_loss: bool = None
    '''Whether to clip the value loss'''

    value_clip: float = None
    '''Clipping range (-value_clip, value_clip), used when the value loss is clipped'''

    gamma: float = None
    '''Discount factor'''

    lam: float = None
    '''Lambda used for GAE'''

    max_grad_norm: float = None
    '''Value that the gradient norm of the value function parameters is clipped to'''

@dataclass
class DiscLearnerCfg:
    '''Settings for learning the discriminator. Every field defaults to ``None``.'''

    opt: str = None
    '''Type of optimizer'''

    opt_params: dict = None
    '''Parameters for the optimizer'''

    reg_coeff: float = None
    '''Coefficient of the discriminator weight regularization'''

@dataclass
class MPAILLearnerCfg:
    '''Training settings for MPAIL, grouping the discriminator, value and
    policy sub-configs. Every field defaults to ``None``.'''

    num_mini_batches: int = None
    '''How many mini batches to train on'''

    num_learning_epochs: int = None
    '''How many epochs to train on'''

    train_disc_every: int = None
    '''The discriminator is trained every n iterations'''

    # --- Discriminator ---

    disc_learner_cfg: DiscLearnerCfg = None
    '''Settings of the discriminator learning algorithm'''

    # --- Value ---

    value_learner_cfg: ValueLearnerCfg = None
    '''Settings of the value function approximating discriminator logit returns'''

    # --- Policy ---

    policy_cfg: MPAILPolicyCfg = None
    '''Settings of the MPPI module (annotated as ``MPAILPolicyCfg``)'''

@dataclass
class MPAILRunnerCfg:
    '''Settings for the runner, which manages stepping MPAIL and logging stats and training details'''

    learner_cfg: MPAILLearnerCfg = None
    '''Settings of the MPAIL learner'''

    num_steps_per_env: int = None
    '''How many steps are taken per environment'''

    num_learning_iterations: int = None
    '''How many learning iterations are run'''

    path_to_demonstrations: str = None
    '''Path to the demonstrations'''

    seed: int = None
    '''Seed for random number generation'''

    logger: Optional[Literal["wandb"]] = None
    '''Type of logger: ``"wandb"`` or ``None``'''

    enable_rl_value: bool = False
    '''If set, learn a value function using rewards from the rl environment'''

    enable_imitation_value: bool = False
    '''If set, learn a value function using rewards from imitation'''

    save_separate_policy_cfg: bool = False
    '''If set, the policy config (.pkl) is saved separately in another file'''
